### gadget_snapshot.py

"""Classes and functions for accessing Gadget data."""

import numpy as npy
import matplotlib.pyplot as plt
import struct
import os
import inspect


IDOFF = 2 * 10**8  # material id offset (200,000,000)


class GadgetHeader:
    """Container for the fields of a Gadget snapshot header.

    Snapshot.load() fills these attributes from the 256-byte header record
    and Snapshot.write() serializes them back in the same field order.

        npart         - particle counts for the six Gadget particle types
                        (SPH / gas particles first)
        mass          - per-type particle masses (six values)
        time          - snapshot time (s); taken from the t argument
        num_files     - number of files the snapshot spans; from nfiles
        flag_entr_ics - taken from the ent argument

    The remaining attributes hold the Gadget header fields of the same
    name; they start at zero, except HubbleParam, which starts at 1.0.
    """

    def __init__(self, t=0.0, nfiles=1, ent=1):
        def six_ints():
            # One integer slot per Gadget particle type.
            return npy.zeros(6, dtype=int)

        self.npart = six_ints()
        self.mass = npy.zeros(6)
        self.time = t
        self.redshift = 0.0
        self.flag_sfr = 0
        self.flag_feedbacktp = 0
        self.npartTotal = six_ints()
        self.flag_cooling = 0
        self.num_files = nfiles
        self.BoxSize = 0.0
        self.Omega0 = 0.0
        self.OmegaLambda = 0.0
        self.HubbleParam = 1.0
        self.flag_stellarage = 0
        self.flag_metals = 0
        self.nallhw = six_ints()
        self.flag_entr_ics = ent



class Snapshot:
    """Gadget snapshot class

       Includes header and gas particle data, with functions for
       reading and writing snapshots.

       load()  -- load Gadget snapshot data
       write() -- save snapshot

       On disk the header and every particle block are stored as records
       framed by a 4-byte int marker before and after the payload.

       header - gadget header
       N    - number of SPH particles
       id   - particle ID
       m    - particle mass (g)
       pos  - position vector, shape (N, 3) after load()
       vel  - velocity vector, shape (N, 3) after load()
       x,y,z - cartesian coordinates (cm)
       vx,vy,vz - cartesian velocities (cm/s)
       S    - specific entropy (erg/K/g)
       rho  - density (g/cm^3)
       hsml - smoothing length (cm)
       pot  - gravitational potential (erg/g)
       P    - pressure (GPa)
       T    - temperature (K)
       U    - specific internal energy (erg/g)
       cs   - sound speed (cm/s)
    """

    # Payload size of the header record: 196 bytes of fields + 60 of padding.
    _HEADER_BYTES = 256

    def __init__(self):
        self.header = GadgetHeader()
        self.N = 0  # number of SPH particles
        self.pos = npy.zeros(3)  # position vector
        self.vel = npy.zeros(3)  # velocity vector
        self.id = 0  # particle ID
        self.m = 0  # mass
        self.S = 0  # entropy
        self.rho = 0  # density
        self.hsml = 0  # smoothing length
        self.pot = 0  # potential
        #Thermo extension
        self.P = 0  # pressure
        self.T = 0  # temperature
        self.U = 0  # internal energy
        self.cs = 0  # sound speed

    # ------------------------------------------------------------------
    # Low-level record helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _read_marker(f):
        """Read one 4-byte record marker; raises struct.error at EOF."""
        (marker,) = struct.unpack('i', f.read(4))
        return marker

    @classmethod
    def _read_block(cls, f, dtype, count):
        """Read one record holding `count` native-endian `dtype` values.

        The surrounding markers are consumed but not validated. The result
        is a fresh, writable array widened to float64 (float data) or the
        default integer type (integer data), i.e. the same values and dtype
        that converting struct-unpacked Python numbers would produce.
        Raises struct.error if the file ends before the payload does.
        """
        cls._read_marker(f)
        nbytes = count * npy.dtype(dtype).itemsize
        payload = f.read(nbytes)
        if len(payload) != nbytes:
            raise struct.error("unpack requires a buffer of %d bytes" % nbytes)
        cls._read_marker(f)
        data = npy.frombuffer(payload, dtype=dtype)
        # astype() copies, so the array does not alias the read-only bytes.
        return data.astype(npy.float64 if data.dtype.kind == 'f' else int)

    @staticmethod
    def _write_block(f, code, count, values):
        """Write `count` values packed with struct code `code` as one record.

        Both markers hold the payload length in bytes. The payload is packed
        before anything is written, so a struct.error (wrong number of
        values, or a value that does not fit `code`) leaves no partial
        record behind.
        """
        payload = struct.pack(str(count) + code, *values)
        marker = struct.pack('i', len(payload))
        f.write(marker)
        f.write(payload)
        f.write(marker)

    def _read_header(self, f):
        """Parse the header record at f's current position into self.header."""
        self._read_marker(f)
        h = self.header
        h.npart = npy.array(struct.unpack('iiiiii', f.read(24)))
        h.mass = npy.array(struct.unpack('dddddd', f.read(48)))
        (h.time, h.redshift, h.flag_sfr,
         h.flag_feedbacktp) = struct.unpack('ddii', f.read(24))
        h.npartTotal = npy.array(struct.unpack('iiiiii', f.read(24)))
        (h.flag_cooling, h.num_files, h.BoxSize,
         h.Omega0, h.OmegaLambda, h.HubbleParam,
         h.flag_stellarage,
         h.flag_metals) = struct.unpack('iiddddii', f.read(48))
        h.nallhw = npy.array(struct.unpack('iiiiii', f.read(24)))
        (h.flag_entr_ics,) = struct.unpack('i', f.read(4))
        struct.unpack('60x', f.read(60))  # padding; still length-checked
        self._read_marker(f)

    def _write_header(self, f):
        """Write self.header as a 256-byte record at f's current position."""
        h = self.header
        f.write(struct.pack('i', self._HEADER_BYTES))
        f.write(struct.pack('iiiiii', *h.npart))
        f.write(struct.pack('dddddd', *h.mass))
        f.write(struct.pack('ddii', h.time, h.redshift,
                            h.flag_sfr, h.flag_feedbacktp))
        f.write(struct.pack('iiiiii', *h.npartTotal))
        f.write(struct.pack('iiddddii', h.flag_cooling,
                            h.num_files, h.BoxSize,
                            h.Omega0, h.OmegaLambda,
                            h.HubbleParam, h.flag_stellarage,
                            h.flag_metals))
        f.write(struct.pack('iiiiii', *h.nallhw))
        f.write(struct.pack('i', h.flag_entr_ics))
        f.write(struct.pack('60x'))
        f.write(struct.pack('i', self._HEADER_BYTES))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def load(self, fname, thermo=True):
        """
            Read gadget snapshot.

            set thermo=False for regular gadget snapshots.

            Only SPH (type 0) particles are read; N is taken from
            header.npart[0]. Blocks are read in the order pos, vel, id,
            m, S, rho, hsml, pot and, when thermo is True, P, T, U, cs.
            Record markers are consumed but not checked, so files whose
            markers do not hold the byte length still load. Float data is
            returned as float64 arrays and IDs as default-int arrays.

            Raises struct.error if the file ends before all expected data
            has been read.
        """
        with open(fname, 'rb') as f:
            self._read_header(f)

            if self.header.num_files != 1:
                print("WARNING! Number of files:", self.header.num_files,
                       ", not currently supported.\n")

            self.N = self.header.npart[0]
            n = self.N

            #PARTICLE DATA (positions/velocities are stored x,y,z per particle)
            self.pos = self._read_block(f, npy.float32, 3 * n).reshape((n, 3))
            self.vel = self._read_block(f, npy.float32, 3 * n).reshape((n, 3))
            self.id = self._read_block(f, npy.int32, n)
            self.m = self._read_block(f, npy.float32, n)
            self.S = self._read_block(f, npy.float32, n)
            self.rho = self._read_block(f, npy.float32, n)
            self.hsml = self._read_block(f, npy.float32, n)
            self.pot = self._read_block(f, npy.float32, n)

            if thermo:
                self.P = self._read_block(f, npy.float32, n)
                self.T = self._read_block(f, npy.float32, n)
                self.U = self._read_block(f, npy.float32, n)
                self.cs = self._read_block(f, npy.float32, n)

        #REARRANGE: per-axis views into pos / vel
        self.x, self.y, self.z = self.pos.T
        self.vx, self.vy, self.vz = self.vel.T

    def write(self, fname, thermo=False):
        """
            Write the snapshot to fname in the layout load() reads.

            Writes the header, then the SPH particle blocks pos, vel, id,
            m, S, rho, hsml, pot (float32 except id, which is int32). With
            thermo=True the P, T, U and cs blocks are appended as well, so
            the file can be read back with load(fname, thermo=True); those
            attributes must then hold N values each.

            Every block is framed by 4-byte markers holding its payload
            length in bytes, matching the header record's 256 markers.

            Raises struct.error if a block does not hold exactly the
            expected number of values or a value cannot be packed.
        """
        with open(fname, 'wb') as f:
            self._write_header(f)

            if self.header.num_files != 1:
                print("WARNING! Number of files:", self.header.num_files,
                       ", not currently supported.\n")

            n = self.N

            #PARTICLE DATA
            self._write_block(f, 'f', 3 * n,
                              npy.array(self.pos).reshape(3 * n))
            self._write_block(f, 'f', 3 * n,
                              npy.array(self.vel).reshape(3 * n))
            self._write_block(f, 'i', n, self.id)
            for values in (self.m, self.S, self.rho, self.hsml, self.pot):
                self._write_block(f, 'f', n, values)

            if thermo:
                for values in (self.P, self.T, self.U, self.cs):
                    self._write_block(f, 'f', n, values)





if __name__ == "__main__":
    import sys

    # Snapshot to inspect: first command-line argument, else a default name.
    snap_file = sys.argv[1] if len(sys.argv) > 1 else 'snapshot_000_long'

    snapshot = Snapshot()
    snapshot.load(snap_file)

    # Number of SPH particles.
    print(snapshot.N)
    # Total internal energy: sum of U (erg/g) * m (g), divided by 1e7
    # to convert erg to joules.
    print((snapshot.U * snapshot.m).sum() / 1.e7)
    
    